Machine learning
Set up the environment if this notebook is executed on Google Colab.
Make sure to change the runtime type to GPU: go to Runtime -> Change runtime type -> GPU.
Otherwise, rendering won't work in Google Colab.
import os
try:
import google.colab
IN_COLAB = True
except:
IN_COLAB = False
if IN_COLAB:
os.system("pip install --quiet 'x_xy[all_muj] @ git+https://github.com/SimiPixel/x_xy_v2'")
os.system("pip install --quiet mediapy")
import x_xy
# Prepare the rendering environment. Per the original note, this helper
# detects by itself whether it runs in Colab, so it is called unconditionally.
# NOTE(review): confirm that behavior against x_xy.utils.setup_colab_env.
x_xy.utils.setup_colab_env()
from x_xy.subpkgs import ml, exp, sys_composer, sim2real
import mediapy
import jax.numpy as jnp
import tree_utils
def load_data_and_prediction(motion, sys, params, *, exp_id="S_04", x_offset=0.03):
    """Load one recorded motion and the RNN filter's prediction for it.

    Args:
        motion: Name of the motion trial to load from experiment ``exp_id``
            (e.g. ``"thomas_fast"``).
        sys: System (including its IMU links) used to turn the recorded
            pose data into transforms; it should belong to ``exp_id``.
        params: Parameters passed to ``ml.RNNOFilter``.
        exp_id: Experiment identifier passed to ``exp.load_data`` and
            ``exp.load_xml_str``. Defaults to ``"S_04"``.
        x_offset: Purely cosmetic shift for rendering. Non-IMU links are
            moved by ``-x_offset`` and IMU links by ``+x_offset`` in
            component 0 of their translation. It does not affect ``yhat``.

    Returns:
        Tuple ``(xs_translated, yhat)``: the shifted recorded transforms and
        the filter prediction with the added batch dimension removed again.
    """
    exp_data = exp.load_data(exp_id, motion)
    xml_str = exp.load_xml_str(exp_id)
    xs = sim2real.xs_from_raw(
        sys, exp.link_name_pos_rot_data(exp_data, xml_str), qinv=True
    )

    # Nudge segments and IMUs apart along x, purely for better optics.
    translations, rotations = sim2real.unzip_xs(sys, xs)
    # Skips the first link name, as before (presumably the root; TODO confirm).
    link_names = sys.link_names[1:]
    # Explicit int dtype so an empty selection is still a valid index array.
    seg_idx = jnp.array(
        [sys.name_to_idx(name) for name in link_names if not name.startswith("imu")],
        dtype=jnp.int32,
    )
    imu_idx = jnp.array(
        [sys.name_to_idx(name) for name in link_names if name.startswith("imu")],
        dtype=jnp.int32,
    )
    pos = translations.pos
    pos = pos.at[:, seg_idx, 0].add(-x_offset)
    pos = pos.at[:, imu_idx, 0].add(x_offset)
    xs_translated = sim2real.zip_xs(sys, translations.replace(pos=pos), rotations)

    # Filter input: rigid IMU readings of seg2..seg4 without the magnetometer.
    # seg3's readings are zeroed, presumably to emulate a setup without an IMU
    # on that segment (TODO confirm).
    X = {}
    for seg in ("seg2", "seg3", "seg4"):
        # Build a filtered copy instead of `pop` so the loaded experiment data
        # is not mutated and a missing "mag" entry does not raise.
        imu_data = {
            key: value
            for key, value in exp_data[seg]["imu_rigid"].items()
            if key != "mag"
        }
        if seg == "seg3":
            imu_data = tree_utils.tree_zeros_like(imu_data)
        X[seg] = dict(imu_data)

    sys_noimu, _ = sys_composer.make_sys_noimu(sys)
    # Named `rnno` to avoid shadowing the builtin `filter`.
    rnno = ml.RNNOFilter(params=params)
    # Initialised with the sample at index 0 along the leading axis of X.
    rnno.init(sys_noimu, tree_utils.tree_slice(X, 0))
    yhat = tree_utils.tree_slice(rnno.predict(tree_utils.add_batch_dim(X)), 0)
    return xs_translated, yhat
# Trial to visualize and the pretrained filter parameters to run on it.
motion = "thomas_fast"
params = ml.load(pretrained="rr_rr_unknown")

# System of experiment S_04, morphed via the "seg2" key, with `seg5` and
# `imu3` passed for deletion after morphing (see `exp.load_sys`).
sys = exp.load_sys(
    "S_04",
    morph_yaml_key="seg2",
    delete_after_morph=["seg5", "imu3"],
)

xs, yhat = load_data_and_prediction(motion=motion, sys=sys, params=params)

# Render from the extra MuJoCo camera "c" (mode="targetbody", target="3").
# NOTE(review): stepframe=4 presumably keeps every 4th frame; fps=25 then
# plays back in real time only if the recording is sampled at 100 Hz --
# confirm the sample rate.
frames = x_xy.render_prediction(
    sys,
    xs,
    yhat,
    stepframe=4,
    width=640,
    height=480,
    camera="c",
    add_cameras={
        -1: '<camera name="c" mode="targetbody" target="3" pos=".5 -.5 1.25"/>',
    },
)
mediapy.show_video(frames, fps=25.0)